""" Norm Layer Factory

Create norm modules by string (to mirror the create_act and create_norm_act fns)

Copyright 2022 Ross Wightman
"""
import types
import functools

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d

# Registry of supported norm layers. Keys are lower-case with no underscores;
# get_norm_layer strips underscores from names before looking them up here.
_NORM_MAP = dict(
    batchnorm=nn.BatchNorm2d,
    batchnorm2d=nn.BatchNorm2d,
    batchnorm1d=nn.BatchNorm1d,
    groupnorm=GroupNorm,
    groupnorm1=GroupNorm1,
    layernorm=LayerNorm,
    layernorm2d=LayerNorm2d,
)
# All concrete layer types present in the registry (several names may map to one type).
# Used to recognise a class that is already a supported norm layer.
_NORM_TYPES = set(_NORM_MAP.values())


def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs):
    """Resolve ``layer_name`` with ``get_norm_layer`` and instantiate it.

    Args:
        layer_name: Anything ``get_norm_layer`` accepts, e.g. a registered string name
            such as ``'batchnorm2d'``, a norm layer class, a factory function, or a
            ``functools.partial`` of one of those.
        num_features: Passed as the first positional argument to the layer constructor.
        act_layer: Kept for signature compatibility; ignored. This factory builds plain
            norm layers and does not forward activation arguments.
        apply_act: Kept for signature compatibility; ignored for the same reason.
        **kwargs: Extra keyword arguments forwarded to the layer constructor.

    Returns:
        The constructed norm layer instance.

    Raises:
        ValueError: if ``layer_name`` does not resolve to a layer.
    """
    # NOTE: act_layer/apply_act were previously forwarded, but get_norm_layer() has no
    # act_layer parameter and torch's nn.BatchNorm1d/2d do not accept apply_act, so every
    # call raised TypeError. They are now deliberately not passed on.
    layer = get_norm_layer(layer_name)
    if layer is None:
        # Avoid an opaque "'NoneType' object is not callable" further down.
        raise ValueError(f'Unable to resolve a norm layer from {layer_name!r}')
    return layer(num_features, **kwargs)


def get_norm_layer(norm_layer):
    """Resolve ``norm_layer`` to a callable that constructs a norm layer.

    Args:
        norm_layer: One of:
            * ``None`` (or an empty string): returned as ``None`` (no norm layer).
            * A string name: underscores are removed, the name is lower-cased, and the
              result is looked up in ``_NORM_MAP`` (e.g. ``'batch_norm2d'``, ``'LayerNorm2d'``).
            * A class already in ``_NORM_TYPES``: returned unchanged.
            * A plain Python function (e.g. a lambda): assumed to build a norm layer and
              returned unchanged.
            * Any other class: mapped via its lower-cased, underscore-stripped ``__name__``
              to the matching ``_NORM_MAP`` entry.
            * A ``functools.partial`` of any of the above: the wrapped callable is
              resolved, then the partial's bound args/keywords are re-applied.

    Returns:
        The resolved class/function, wrapped in ``functools.partial`` when arguments
        were bound, or ``None``.

    Raises:
        TypeError: if ``norm_layer`` is not one of the accepted kinds.
        ValueError: if a name or class has no equivalent entry in ``_NORM_MAP``.
    """
    if norm_layer is None:
        return None
    # Explicit raise rather than assert so the check survives `python -O`.
    if not isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)):
        raise TypeError(f'Unsupported norm_layer type: {type(norm_layer).__name__}')

    # Unbind a partial so its arguments can be rebound onto the resolved layer.
    # Positional args are kept too; previously they were silently discarded.
    norm_args = ()
    norm_kwargs = {}
    if isinstance(norm_layer, functools.partial):
        norm_args = norm_layer.args
        norm_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, str):
        if not norm_layer:
            return None
        # Normalize the same way as the class-name path below (strip '_' and lower-case).
        layer_name = norm_layer.replace('_', '').lower()
        resolved = _NORM_MAP.get(layer_name)
        if resolved is None:
            raise ValueError(f'Unknown norm layer name: {norm_layer!r}')
        norm_layer = resolved
    elif norm_layer not in _NORM_TYPES and not isinstance(norm_layer, types.FunctionType):
        # A class (or other callable unwrapped from a partial) that isn't registered:
        # map it to the registered layer sharing its normalized name.
        type_name = norm_layer.__name__.lower().replace('_', '')
        resolved = _NORM_MAP.get(type_name)
        if resolved is None:
            raise ValueError(f'No equivalent norm layer for {type_name}')
        norm_layer = resolved
    # Otherwise: an already-registered norm type, or a function assumed to create a norm
    # layer. Both are used as-is.

    if norm_args or norm_kwargs:
        norm_layer = functools.partial(norm_layer, *norm_args, **norm_kwargs)  # bind/rebind args
    return norm_layer
